import  numpy as np
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Stack,Pad,Tuple

def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    """
        Turns one raw example into the id sequences fed to a sequence classification model.

        The tokenizer is called as ``tokenizer(text=example['text'], max_seq_len=max_seq_length)``
        and its ``input_ids`` and ``token_type_ids`` entries are returned unchanged. How special
        tokens are inserted and how over-long text is handled is decided by the tokenizer itself.

        Args:
            example(obj:`dict`): A mapping with a ``text`` entry, plus a ``label`` entry
                when ``is_test`` is False.
            tokenizer(obj:`PretrainedTokenizer`): A callable tokenizer accepting the ``text`` and
                ``max_seq_len`` keyword arguments and returning a mapping that contains
                ``input_ids`` and ``token_type_ids``.
            max_seq_length(obj:`int`, defaults to 512): Forwarded to the tokenizer as ``max_seq_len``.
            is_test(obj:`bool`, defaults to `False`): When True the label is not read and not returned.

        Returns:
            input_ids: The ``input_ids`` entry produced by the tokenizer.
            token_type_ids: The ``token_type_ids`` entry produced by the tokenizer.
            label(obj:`numpy.array`, data type of int64, optional): One-element array holding
                ``example['label']``; only present when ``is_test`` is False.
        """
    encoded = tokenizer(text=example['text'], max_seq_len=max_seq_length)
    features = (encoded['input_ids'], encoded['token_type_ids'])

    # Inference-time examples carry no label, so only the two id sequences go back.
    if is_test:
        return features

    return features + (np.array([example['label']], dtype='int64'),)

def predict(model, data, tokenizer, label_map, batch_size=1, max_seq_length=128):
    """
    Predicts the data labels.
    :param model(obj:`paddle.nn.Layer`): A model to classify texts. It is called as
            `model(input_ids, segment_ids)` and is left in eval mode afterwards.
    :param data(obj:`list`): The raw examples. Each one is passed to `convert_example` with
            `is_test=True`, so each must provide a `text` entry.
    :param tokenizer(obj:`PretrainedTokenizer`): The tokenizer forwarded to `convert_example`;
            its `pad_token_id` is used as the padding value when batching.
    :param label_map(obj:`dict`): The label id (key) to label str (value) map.
    :param batch_size(obj:`int`, defaults to 1): The number of examples per batch.
    :param max_seq_length(obj:`int`, defaults to 128): The maximum sequence length forwarded
            to `convert_example`.
    :return: results(obj:`list`): The predicted labels, one per input example, in input order.
            An empty list when `data` is empty.
    :raises ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got %r' % (batch_size,))

    examples = [
        convert_example(
            text,
            tokenizer,
            max_seq_length=max_seq_length,
            is_test=True
        )
        for text in data
    ]

    # Pads every field of a batch to the longest sequence in that batch.
    # NOTE(review): segment ids are padded with pad_token_id as in the original code;
    # confirm whether tokenizer.pad_token_type_id was intended instead.
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input id
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment id
    )

    results = []
    model.eval()
    # Gradients are not needed for inference.
    with paddle.no_grad():
        # Separates data into batches; the last one may be smaller than batch_size.
        for start in range(0, len(examples), batch_size):
            batch = examples[start:start + batch_size]
            input_ids, segment_ids = batchify_fn(batch)
            input_ids = paddle.to_tensor(input_ids)
            segment_ids = paddle.to_tensor(segment_ids)
            logits = model(input_ids, segment_ids)
            probs = F.softmax(logits, axis=1)
            idx = paddle.argmax(probs, axis=1).numpy().tolist()
            results.extend(label_map[i] for i in idx)
    # Return only after every batch has been processed.
    return results

@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Given a dataset, it evals model and computes the metric.

    The mean loss and the accumulated metric value are printed. The model is switched
    back to train mode and the metric is reset before returning.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss from logits and labels.
        metric(obj:`paddle.metric.Metric`): The evaluation metric. Its `accumulate()` result
            is printed with a float format, so it must be a single number.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches of
            `(input_ids, token_type_ids, labels)`.
    """
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        correct = metric.compute(logits, labels)
        metric.update(correct)

    # Read the metric once after the loop so it is defined even for an empty loader.
    accu = metric.accumulate()
    # np.mean of an empty list warns and yields nan; report nan explicitly instead.
    mean_loss = np.mean(losses) if losses else float('nan')
    print('eval loss: %.5f, accu: %.5f' % (mean_loss, accu))
    model.train()
    metric.reset()

def create_dataloader(
        datasets,
        mode='train',
        batch_size=1,
        batchify_fn=None,
        trans_fn=None
):
    """
    Builds a `paddle.io.DataLoader` over a dataset.

    Args:
        datasets: The dataset to load. It must offer a `map` method when `trans_fn` is given.
        mode(obj:`str`, defaults to `'train'`): With `'train'` the data is shuffled and a
            `paddle.io.DistributedBatchSampler` is used; any other value uses a plain
            `paddle.io.BatchSampler` without shuffling.
        batch_size(obj:`int`, defaults to 1): The number of examples per batch.
        batchify_fn(obj:`callable`, optional): Passed to the loader as `collate_fn`.
        trans_fn(obj:`callable`, optional): If given, applied to the dataset through
            `datasets.map` before the sampler is built.

    Returns:
        obj:`paddle.io.DataLoader`: The loader, created with `return_list=True`.
    """
    if trans_fn:
        datasets = datasets.map(trans_fn)

    # The sampler is built whether or not a transform was supplied.
    is_train = mode == 'train'
    if is_train:
        batch_sampler = paddle.io.DistributedBatchSampler(
            datasets, batch_size=batch_size, shuffle=is_train
        )
    else:
        batch_sampler = paddle.io.BatchSampler(
            datasets, batch_size=batch_size, shuffle=is_train
        )
    return paddle.io.DataLoader(
        dataset=datasets,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True
    )

